import matplotlib.pyplot as plt
import numpy as np


def print_results(title, errors):
    """Print MSE, PSNR and DSSIM averages for one set of predictions.

    Infinite entries are treated as missing values. The data is averaged over
    its first axis (ignoring missing values). A final column holding the mean
    of each remaining row is then appended. Rows 0, 1 and 2 are printed as
    MSE, PSNR and DSSIM.

    Args:
        title: Heading printed above the results.
        errors: Array-like of per-sample errors. Presumably shaped
            (samples, metric, time_step) with metrics ordered MSE, PSNR,
            DSSIM -- inferred from the printed labels; confirm against the
            caller. The caller's array is not modified.

    Returns:
        The printed summary as a float array: row k holds the per-column
        means of metric k followed by the mean of those values.
    """
    # Work on a float copy: the caller's data stays untouched, and NaN can be
    # stored even when an integer array is passed in.
    errors = np.array(errors, dtype=float)

    # Remove prediction singularities (+inf or -inf) so they do not dominate
    # the averages below.
    errors[np.isinf(errors)] = np.nan

    # Average over the first axis, skipping the NaNs introduced above, then
    # append the mean across each row as an extra last column.
    summary = np.nanmean(errors, axis=0)
    summary = np.concatenate((summary, np.mean(summary, axis=1, keepdims=True)), axis=1)

    # Single-argument print(...) behaves the same on Python 2 and Python 3.
    print('\n\n' + title + ':')
    print('MSE: ' + str(summary[0]))
    print('PSNR: ' + str(summary[1]))
    print('DSSIM: ' + str(summary[2]))
    return summary


def make_plot(measures, limits=None):
    """Draw a grid of per-time-step metric curves with one shared legend.

    The grid has one row per entry of ``measures`` and three columns. Their
    y-axes are labelled 'Average MSE', 'Average PSNR' and 'Average DSSIM'.
    The finished figure is displayed with ``plt.show()``.

    Args:
        measures: Non-empty sequence of rows.
            - Each row holds three lists of series (MSE, PSNR, DSSIM columns).
            - Each series is a ``(label, values, style)`` tuple. ``values[k]``
              is drawn at x = k + 1, and ``style`` is a matplotlib format
              string such as ``'k-'``.
            - Legend entries come from the first row's MSE column, so every
              row is expected to list the same series in the same order.
        limits: Optional sequence with one entry per row.
            - An entry is either None (autoscale the whole row) or three
              ``[bottom, top]`` pairs, one per column.
            - A None bound leaves that side autoscaled.
            - Rows without an entry are autoscaled.

    Raises:
        ValueError: If ``measures`` is empty.
    """
    y_labels = ('Average MSE', 'Average PSNR', 'Average DSSIM')

    def lineplot(y_label, series, y_limits):
        # Draw every series on the current axes, format the axes once, and
        # return the Line2D objects that were created.
        lines = []
        n_steps = 0
        for _label, values, style in series:
            lines += plt.plot(range(1, len(values) + 1), values, style)
            n_steps = max(n_steps, len(values))
        if not lines:
            # Nothing drawn: leave the axes unformatted, as before.
            return lines

        # Axis formatting is identical for every series, so apply it once.
        plt.xlabel('time step')
        plt.ylabel(y_label)
        plt.xticks(range(1, n_steps + 1))
        if n_steps > 1:
            plt.xlim(1, n_steps)
        bottom, top = y_limits[0], y_limits[1]
        # Apply the limits if either bound is given; a None bound keeps that
        # side at its current (autoscaled) value.
        if bottom is not None or top is not None:
            plt.ylim(bottom, top)
        plt.grid(True, linestyle='dashed')
        return lines

    if not measures:
        raise ValueError('measures must contain at least one row')
    n_rows = len(measures)

    # Give every row a limits entry. Without padding, zip() below would
    # silently drop rows that have no matching limits entry.
    row_limits = list(limits) if limits is not None else []
    row_limits += [None] * (n_rows - len(row_limits))
    row_limits = [[[None, None]] * 3 if lim is None else lim for lim in row_limits]

    # Create one subplot per (row, metric) in row-major order.
    fig = plt.figure(figsize=(16, 10), dpi=70)
    legend_lines = []
    for i, (row, row_lim) in enumerate(zip(measures, row_limits)):
        for j, y_label in enumerate(y_labels):
            fig.add_subplot(n_rows, 3, 3 * i + j + 1)
            lines = lineplot(y_label, row[j], row_lim[j])
            if i == 0 and j == 0:
                # Take legend handles from the same series list the labels
                # are read from (first row, MSE column).
                legend_lines = lines

    # Display a single legend for the whole figure.
    plt.subplots_adjust(top=1, bottom=0.09, left=0.05, right=0.99)
    labels = tuple(label for label, _values, _style in measures[0][0])
    plt.figlegend(legend_lines, labels, loc='lower center', ncol=10, fontsize=12, frameon=False)

    plt.show()

if __name__ == '__main__':
    # Placeholder for results that are not available yet. The original all-zero
    # rows are inconsistent as measurements: MSE 0 would imply infinite PSNR,
    # not PSNR 0. matplotlib does not draw NaN points, so these series are
    # simply absent instead of being drawn as a misleading flat line at 0.
    missing = [float('nan')] * 10

    # One tuple per subplot row, holding the (MSE, PSNR, DSSIM) series lists.
    # Each series is (label, per-time-step values, matplotlib style).
    make_plot([([
        ('Baseline', [0.04923, 0.06121, 0.06642, 0.07147, 0.07338, 0.07512, 0.07511, 0.07567, 0.07542, 0.07590], 'k-'),
        ('RLadder', [0.04251, 0.04252, 0.04252, 0.04254, 0.04255, 0.04253, 0.04254, 0.04255, 0.04255, 0.04254], 'g-'),
        ('Prednet', [0.02838, 0.04408, 0.04285, 0.04285, 0.04270, 0.04255, 0.04258, 0.04255, 0.04256, 0.04254], 'r-'),
        ('Srivastava', [0.00885, 0.01097, 0.01308, 0.01496, 0.01680, 0.01856, 0.02035, 0.02185, 0.02335, 0.02473], 'y-'),
        ('Mathieu', [0.022462, 0.032085, 0.037718, 0.043201, 0.043589, 0.043213, 0.044363, 0.045566, 0.046795, 0.047694], 'c-'),
        ('Villegas', [0.04251, 0.04252, 0.04252, 0.04254, 0.04255, 0.04253, 0.04254, 0.04255, 0.04255, 0.04255], 'm-'),
        ('fRNN', [0.00475, 0.00578, 0.00686, 0.00784, 0.00887, 0.00994, 0.01105, 0.01207, 0.01319, 0.01435], 'b-'),
        ('RLadder (pre-trained)', [0.00760, 0.00978, 0.01217, 0.01432, 0.01651, 0.01851, 0.02047, 0.02229, 0.02401, 0.02567], 'g--'),
    ], [
        ('Baseline', [13.233, 12.266, 11.937, 11.601, 11.513, 11.396, 11.407, 11.362, 11.388, 11.350], 'k-'),
        ('RLadder', [13.860, 13.859, 13.860, 13.858, 13.856, 13.858, 13.856, 13.855, 13.855, 13.856], 'g-'),
        ('Prednet', [15.684, 13.711, 13.828, 13.831, 13.843, 13.857, 13.853, 13.855, 13.855, 13.855], 'r-'),
        ('Srivastava', [20.809, 19.916, 19.177, 18.601, 18.103, 17.681, 17.276, 16.960, 16.671, 16.421], 'y-'),
        ('Mathieu', [16.4688, 14.9215, 14.2196, 13.6307, 13.5919, 13.6295, 13.5155, 13.3994, 13.2839, 13.2013], 'c-'),
        ('Villegas', [13.860, 13.859, 13.860, 13.858, 13.856, 13.858, 13.856, 13.855, 13.855, 13.856], 'm-'),
        ('fRNN', [24.208, 23.287, 22.566, 21.983, 21.455, 20.949, 20.471, 20.060, 19.634, 19.242], 'b-'),
        ('RLadder (pre-trained)', [21.703, 20.660, 19.674, 18.942, 18.291, 17.764, 17.291, 16.884, 16.531, 16.212], 'g--'),
    ], [
        ('Baseline', [0.15520, 0.17771, 0.19192, 0.20677, 0.21422, 0.22155, 0.22383, 0.22647, 0.22637, 0.22770], 'k-'),
        ('RLadder', [0.13797, 0.13776, 0.13783, 0.13785, 0.13780, 0.13777, 0.13789, 0.13799, 0.13802, 0.13791], 'g-'),
        ('Prednet', [0.11971, 0.16172, 0.15431, 0.14562, 0.14292, 0.13912, 0.13945, 0.13909, 0.13920, 0.13899], 'r-'),
        ('Srivastava', [0.05095, 0.05916, 0.06735, 0.07426, 0.08072, 0.08661, 0.09239, 0.09707, 0.10150, 0.10544], 'y-'),
        ('Mathieu', [0.1601, 0.2268, 0.2835, 0.3486, 0.3765, 0.4050, 0.4171, 0.4240, 0.4273, 0.4334], 'c-'),
        ('Villegas', [0.13905, 0.13885, 0.13891, 0.13894, 0.13889, 0.13886, 0.13898, 0.13908, 0.13910, 0.13899], 'm-'),
        ('fRNN', [0.02375, 0.02854, 0.03336, 0.03762, 0.04180, 0.04612, 0.05047, 0.05444, 0.05871, 0.06275], 'b-'),
        ('RLadder (pre-trained)', [0.03779, 0.04691, 0.05734, 0.06629, 0.07471, 0.08238, 0.08952, 0.09586, 0.10140, 0.10631], 'g--'),
    ]), ([
        ('Baseline', [0.00103, 0.00204, 0.00280, 0.00338, 0.00383, 0.00420, 0.00450, 0.00475, 0.00497, 0.00515], 'k-'),
        ('RLadder', [0.00080, 0.00064, 0.00087, 0.00110, 0.00132, 0.00151, 0.00169, 0.00184, 0.00199, 0.00213], 'g-'),
        ('Prednet', [0.00144, 0.00447, 0.00361, 0.00673, 0.00580, 0.00907, 0.00856, 0.01218, 0.01237, 0.01645], 'r-'),
        ('Srivastava', [0.00839, 0.00852, 0.00893, 0.00940, 0.00983, 0.01024, 0.01061, 0.01093, 0.01121, 0.01145], 'y-'),
        ('Mathieu', [0.0006567, 0.0010211, 0.0012615, 0.0014319, 0.0016421, 0.0017737, 0.0019078, 0.0021111, 0.0021855, 0.0023709], 'c-'),
        ('Villegas', [0.00030, 0.00063, 0.00098, 0.00132, 0.00161, 0.00189, 0.00214, 0.00234, 0.00254, 0.00274], 'm-'),
        ('fRNN', [0.00074, 0.00097, 0.00122, 0.00147, 0.00170, 0.00190, 0.00210, 0.00228, 0.00246, 0.00262], 'b-'),
        ('RLadder (pre-trained)', missing, 'g--'),
    ], [
        ('Baseline', [34.780, 31.774, 30.223, 29.233, 28.515, 27.973, 27.535, 27.180, 26.878, 26.619], 'k-'),
        ('RLadder', [31.737, 34.142, 33.018, 32.109, 31.393, 30.834, 30.375, 30.000, 29.677, 29.397], 'g-'),
        ('Prednet', [31.360, 26.974, 27.532, 24.952, 25.300, 23.329, 23.340, 21.776, 21.508, 20.280], 'r-'),
        ('Srivastava', [21.974, 21.922, 21.691, 21.439, 21.234, 21.048, 20.892, 20.762, 20.653, 20.559], 'y-'),
        ('Mathieu', [33.1342, 31.8160, 31.2525, 30.7705, 30.4912, 29.9523, 29.6754, 29.3361, 29.1516, 28.8458], 'c-'),
        ('Villegas', [37.575, 34.621, 32.709, 31.430, 30.401, 29.575, 28.940, 28.509, 28.061, 27.640], 'm-'),
        ('fRNN', [32.044, 31.106, 30.320, 29.683, 29.165, 28.749, 28.400, 28.097, 27.829, 27.596], 'b-'),
        ('RLadder (pre-trained)', missing, 'g--'),
    ], [
        ('Baseline', [0.02873, 0.04799, 0.06156, 0.07211, 0.08091, 0.08836, 0.09483, 0.10040, 0.10536, 0.10978], 'k-'),
        ('RLadder', [0.03249, 0.03745, 0.04533, 0.05268, 0.05916, 0.06473, 0.06965, 0.07395, 0.07780, 0.08125], 'g-'),
        ('Prednet', [0.04704, 0.09218, 0.08831, 0.12188, 0.12168, 0.15040, 0.15538, 0.18037, 0.19019, 0.21140], 'r-'),
        ('Srivastava', [0.18878, 0.18911, 0.19203, 0.19530, 0.19809, 0.20076, 0.20307, 0.20497, 0.20655, 0.20788], 'y-'),
        ('Mathieu', [0.0656, 0.0774, 0.0851, 0.0947, 0.0994, 0.1082, 0.1093, 0.1200, 0.1206, 0.1298], 'c-'),
        ('Villegas', [0.01778, 0.03261, 0.04741, 0.06162, 0.07656, 0.09009, 0.09973, 0.10550, 0.11346, 0.12094], 'm-'),
        ('fRNN', [0.04057, 0.05004, 0.05858, 0.06605, 0.07262, 0.07830, 0.08335, 0.08787, 0.09200, 0.09571], 'b-'),
        ('RLadder (pre-trained)', missing, 'g--'),
    ]), ([
        ('Baseline', [0.00412, 0.00751, 0.00992, 0.01179, 0.01332, 0.01456, 0.01566, 0.01665, 0.01753, 0.01830], 'k-'),
        ('RLadder', [0.00365, 0.00516, 0.00674, 0.00811, 0.00925, 0.01022, 0.01108, 0.01185, 0.01254, 0.01316], 'g-'),
        ('Prednet', [0.00274, 0.00878, 0.00874, 0.01523, 0.01589, 0.02313, 0.02436, 0.03307, 0.03527, 0.04516], 'r-'),
        ('Srivastava', [0.00908, 0.05399, 0.11943, 0.16735, 0.18014, 0.17885, 0.18194, 0.19184, 0.19989, 0.20404], 'y-'),
        ('Mathieu', [0.00646, 0.00708, 0.00869, 0.00875, 0.00861, 0.01042, 0.01210, 0.01252, 0.01475, 0.01773], 'c-'),
        ('Villegas', [0.00268, 0.00482, 0.00655, 0.00812, 0.00940, 0.01040, 0.01150, 0.01261, 0.01373, 0.01443], 'm-'),
        ('fRNN', [0.00274, 0.00481, 0.00652, 0.00795, 0.00920, 0.01029, 0.01122, 0.01204, 0.01273, 0.01334], 'b-'),
        ('RLadder (pre-trained)', missing, 'g--'),
    ], [
        ('Baseline', [28.993, 25.335, 23.812, 22.874, 22.230, 21.765, 21.362, 21.007, 20.722, 20.486], 'k-'),
        ('RLadder', [26.332, 25.979, 24.851, 23.992, 23.373, 22.893, 22.490, 22.154, 21.876, 21.641], 'g-'),
        ('Prednet', [29.537, 23.693, 23.612, 20.923, 20.554, 18.746, 18.403, 16.859, 16.477, 15.180], 'r-'),
        ('Srivastava', [21.077, 13.125, 9.876, 8.515, 8.212, 8.212, 8.094, 7.849, 7.670, 7.577], 'y-'),
        ('Mathieu', [22.634, 22.406, 21.488, 21.674, 22.192, 21.262, 20.591, 20.699, 19.653, 18.804], 'c-'),
        ('Villegas', [29.389, 26.389, 24.759, 23.765, 22.959, 22.440, 21.854, 21.401, 20.940, 20.671], 'm-'),
        ('fRNN', [28.942, 26.411, 25.011, 24.086, 23.402, 22.878, 22.453, 22.111, 21.830, 21.593], 'b-'),
        ('RLadder (pre-trained)', missing, 'g--'),
    ], [
        ('Baseline', [0.06650, 0.10611, 0.12946, 0.14549, 0.15735, 0.16643, 0.17415, 0.18086, 0.18654, 0.19142], 'k-'),
        ('RLadder', [0.06874, 0.09301, 0.11197, 0.12665, 0.13804, 0.14714, 0.15475, 0.16118, 0.16664, 0.17136], 'g-'),
        ('Prednet', [0.05345, 0.12480, 0.12507, 0.17436, 0.17825, 0.21586, 0.21961, 0.25540, 0.26117, 0.29331], 'r-'),
        ('Srivastava', [0.13221, 0.34915, 0.41874, 0.46370, 0.47410, 0.47660, 0.47985, 0.48424, 0.48746, 0.48940], 'y-'),
        ('Mathieu', [0.0905, 0.0943, 0.1031, 0.1036, 0.1010, 0.1059, 0.1126, 0.1157, 0.1252, 0.1429], 'c-'),
        ('Villegas', [0.05502, 0.09051, 0.11414, 0.13304, 0.14620, 0.15725, 0.16712, 0.17599, 0.18492, 0.19083], 'm-'),
        ('fRNN', [0.05446, 0.08526, 0.10694, 0.12300, 0.13569, 0.14580, 0.15421, 0.16115, 0.16700, 0.17195], 'b-'),
        ('RLadder (pre-trained)', missing, 'g--'),
    ])], limits=(
        None,
        [[0, 0.006], [20, 38], [0.01, 0.12]],
        [[0.0025, 0.02], [20, 30], [0.05, 0.20]],
    ))
